package com.aloha.middleware.db.router;

import com.aloha.middleware.db.router.annotation.DBRouter;
import com.aloha.middleware.db.router.config.DBRouterProperties;
import com.aloha.middleware.db.router.strategy.IDBRouterStrategy;
import com.aloha.middleware.db.router.utils.DBContextHolder;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.StringUtils;

import java.lang.reflect.Field;

/**
 * AOP aspect that performs data routing for methods annotated with {@link DBRouter}.
 * <p>
 * Before the intercepted method runs, the aspect resolves the routing key name (annotation key,
 * falling back to {@link DBRouterProperties#getRouterKey()}). It extracts that attribute's value
 * from the method arguments and passes it to {@link IDBRouterStrategy#doRouter}.
 * {@link IDBRouterStrategy#clear()} is always invoked in a {@code finally} block afterwards.
 *
 * @author DaiZhiHeng
 * @description Data-routing aspect
 * @date 2023/7/9 11:17
 */
@Aspect
public class DBRouterAspect {
    private static final Logger log = LoggerFactory.getLogger(DBRouterAspect.class);

    private final DBRouterProperties prop;
    private final IDBRouterStrategy dbRouterStrategy;

    /**
     * @param prop             router configuration; supplies the default router key
     * @param dbRouterStrategy routing strategy that receives the extracted attribute value
     */
    public DBRouterAspect(DBRouterProperties prop, IDBRouterStrategy dbRouterStrategy) {
        this.prop = prop;
        this.dbRouterStrategy = dbRouterStrategy;
    }

    /**
     * Routes the call, proceeds with the intercepted method, then clears the routing state.
     *
     * @param jp       the intercepted join point
     * @param dbRouter the annotation on the intercepted method
     * @return whatever the intercepted method returns
     * @throws RuntimeException if neither the annotation nor the configuration provides a router key
     * @throws Throwable        anything thrown by routing or by the intercepted method
     */
    @Around("@annotation(dbRouter)")
    public Object doRouter(ProceedingJoinPoint jp, DBRouter dbRouter) throws Throwable {
        // The annotation's own key takes precedence; fall back to the globally configured key.
        String dbKey = StringUtils.hasText(dbRouter.key()) ? dbRouter.key() : prop.getRouterKey();
        if (!StringUtils.hasText(dbKey)) {
            throw new RuntimeException("annotation DBRouter key is null！");
        }
        // Value of the routing attribute, taken from the method arguments (may be null if absent).
        String dbKeyAttr = getAttrValue(dbKey, jp.getArgs());
        try {
            // Route inside the try so clear() still runs if routing itself fails part-way.
            dbRouterStrategy.doRouter(dbKeyAttr);
            return jp.proceed();
        } finally {
            dbRouterStrategy.clear();
        }
    }

    /**
     * Extracts the routing attribute value from the method arguments.
     * <p>
     * If there is exactly one argument and it is a {@link String}, that string is returned as-is,
     * regardless of {@code attr}. Otherwise the arguments are scanned in order, and the first one
     * that declares (or inherits) a field named {@code attr} with a non-blank string form wins.
     * Null arguments and arguments without such a field are skipped.
     *
     * @param attr field name of the routing attribute
     * @param args method arguments
     * @return the attribute value as a string, or {@code null} if no argument provides one
     */
    public String getAttrValue(String attr, Object[] args) {
        if (args == null || args.length == 0) {
            return null;
        }
        if (args.length == 1 && args[0] instanceof String) {
            return (String) args[0];
        }
        for (Object arg : args) {
            if (arg == null) {
                continue;
            }
            try {
                Object value = getValueByName(arg, attr);
                // Avoid String.valueOf(null) -> "null", which would be treated as a real routing value.
                if (value != null) {
                    String text = String.valueOf(value);
                    if (StringUtils.hasText(text)) {
                        return text;
                    }
                }
            } catch (Exception e) {
                log.error("获取路由属性值失败 attr：{}", attr, e);
            }
        }
        return null;
    }

    /**
     * Reads the value of a named field from an object via reflection.
     *
     * @param item the object to read from (must not be null)
     * @param name field name
     * @return the field value, or {@code null} if the field does not exist or cannot be accessed
     * @author tang
     */
    private Object getValueByName(Object item, String name) {
        Field field = getFieldByName(item, name);
        if (field == null) {
            return null;
        }
        try {
            // The Field returned by getDeclaredField is a fresh copy, so there is no need to reset
            // the accessible flag afterwards.
            field.setAccessible(true);
            return field.get(item);
        } catch (IllegalAccessException e) {
            return null;
        }
    }

    /**
     * Finds a field by name on the object's class or any of its superclasses.
     *
     * @param item the object whose class hierarchy is searched (must not be null)
     * @param name field name
     * @return the first matching declared field walking up from the runtime class, or {@code null}
     */
    private Field getFieldByName(Object item, String name) {
        for (Class<?> clazz = item.getClass(); clazz != null; clazz = clazz.getSuperclass()) {
            try {
                return clazz.getDeclaredField(name);
            } catch (NoSuchFieldException ignored) {
                // Not declared at this level; keep walking up the hierarchy.
            }
        }
        return null;
    }
}
